#!/bin/sh
_=[[
# vim: filetype=lua ts=2 sts=2 sw=2 et ai
if command -v luajit >/dev/null 2>&1; then
  exec luajit "$0" "$@"
fi
if command -v lua >/dev/null 2>&1; then
  exec lua "$0" "$@"
fi
echo 'requires lua or luajit in $PATH'
exit 1
]]

--------------------------------------------------------------------------------
-- The mekong compiler driver. It provides a gcc/nvcc "compatible" interface to
-- compile applications and apply mekong partitioning.
-- The design is inspired by the clang driver and is divided into the
-- following steps:
-- 1. parse arguments
-- 2. construct actions to get to desired output using all inputs
-- 3. bind actions to filenames
-- 4. translate arguments to bindings (i.e. assemble commands)
-- 5. execute commands
--------------------------------------------------------------------------------

-- Name of this script: the last path component of arg[0].
local app = (arg[0]):match(".-([^%/]+)$")

-- Directory part of arg[0] without a trailing slash. The tool, library and
-- include locations used by the driver are built relative to it.
local script_dir = (arg[0]):match("(.-)[^%/]+$")
if script_dir == "" then
  -- arg[0] has no directory component (e.g. `lua mekc`), so the script lives
  -- in the current directory. Leaving this empty would make every derived
  -- path start with "/.." and point at the filesystem root instead.
  script_dir = "."
else
  -- Strip one trailing slash; a script directly under "/" yields "" here,
  -- which still produces correct absolute paths once "/..." is appended.
  script_dir = script_dir:match("(.-)/?$")
end

-- require() searches only the Lua library directory shipped next to the script.
package.path = script_dir .. "/../lib/lua/?.lua"
local cargs = require('cargs')

-- Locations of the external tools and runtime files, all built relative to
-- script_dir. The @...@ tokens look like configure-time placeholders that the
-- build system substitutes -- TODO confirm against the build files.
CLANGXX = script_dir .. "/clang++"
CLANG = script_dir .. "/clang"
MEKONGRW = script_dir .. "/mekongrw"
LIBMEKONG = script_dir .. "/../lib/LLVMMekong@CMAKE_SHARED_LIBRARY_SUFFIX@"
MERTLIB = script_dir .. "/../lib/libmekongrt.bc"
MERTINC = script_dir .. "/../include/mekong"
VERSION = "@MEKONG_VERSION@"

-- Defaults applied by parse() when --cuda-gpu-arch / --cuda-path are not given.
-- A nil CUDAPATH means no --cuda-path option is passed on to clang.
CUDAARCH = "sm_30"
CUDAPATH = nil

--------------------------------------------------------------------------------
-- globals

-- Which pipeline phases are enabled. phases() rewrites every entry except
-- "preprocess" from the -E/-A/-R/-c stop flags.
local steps = {
  preprocess = true,
  analyze = true,
  rewrite = true,
  compile = true,
  link = true,
}

-- Filled in by parse(): positional inputs, parsed options, and the first line
-- printed by `uname -s`.
local arguments = nil
local opts = nil
local os_name = nil

-- Driver state, built up step by step:
--   actions  : {phase, inputs, result type} entries created by pipeline()
--   bindings : {phase, input files, output file, include=...} created by bind()
--   cmds     : argument lists created by translate() and run by execute()
local actions = {}
local bindings = {}
local cmds = {}

-- Set of file names (used as keys) that cleanup() removes.
local temps = {}

--------------------------------------------------------------------------------
-- steps

-- Step 1: parse the command line.
-- Declares all supported options, parses them, and stores the results in the
-- file-level `opts` (option values) and `arguments` (positional inputs). Also
-- records the OS name in `os_name` and applies default values.
-- Exits the process for --version; raises an error when no inputs are given
-- or when more than one of -E, -A, -R, -c is present.
function parse()
  -- Named `cli` so it does not shadow the file-level `app` (the script name).
  local cli = cargs.App:new("mekc", "The mekong compiler.", "<inputs>")

  -- stop flags: end the pipeline after the named phase
  cli:add_option('flag', 'once', {'-'}, 'E', nil, 'Stop after preprocessing.')
  cli:add_option('flag', 'once', {'-'}, 'A', nil, 'Stop after analysis.')
  cli:add_option('flag', 'once', {'-'}, 'R', nil, 'Stop after rewriting.')
  cli:add_option('flag', 'once', {'-'}, 'c', nil, 'Stop after compile and assembly.')

  -- NOTE(review): this is the only option declared as 'joined_or_separate';
  -- the others use 'joinsep'. Confirm that cargs accepts both spellings.
  cli:add_option('joined_or_separate', 'once', {'-'}, 'o', nil, 'Write results to <file>.', '<file>')

  cli:add_option('flag', 'last', {'-'}, 'g', nil, 'Include debug info.')

  cli:add_option('joined', 'last', {'-'}, 'O', nil, 'Optimization level.')
  cli:add_option('joinsep', 'list', {'-'}, 'D', nil, 'Define Macro.')
  cli:add_option('joined', 'once', {'-', '--'}, 'std=', 'std', 'Language standard.', '<version>')

  cli:add_option('joinsep', 'list', {'-'}, 'I', nil, 'Add directory to include search path.', '<dir>')
  cli:add_option('joinsep', 'list', {'-'}, 'L', nil, 'Add directory to library search path.', '<dir>')
  cli:add_option('joinsep', 'list', {'-'}, 'l', nil, 'Link .', '<library>')

  cli:add_option('joined', 'once', {'--'}, 'cuda-gpu-arch=', 'cuda-gpu-arch',
    'GPU target architecture.', '<arch>')
  cli:add_option('joined', 'once', {'--'}, 'cuda-path=', 'cuda-path',
    'CUDA SDK path.', '<dir>')

  -- debugging aids
  cli:add_option('flag', 'last', {'-', '--'}, 'ccc-print-phases', nil, "Print phases.")
  cli:add_option('flag', 'last', {'-', '--'}, 'ccc-print-actions', nil, "Print actions.")
  cli:add_option('flag', 'last', {'-', '--'}, 'ccc-print-bindings', nil, "Print bindings.")
  cli:add_option('flag', 'last', {'-', '--'}, '###', nil, "Print (dont' execute) commands to run.")
  cli:add_option('flag', 'last', {'-', '--'}, 'v', nil, "Print commands to run.")

  cli:add_option('flag', 'last', {'-', '--'}, 'version', nil, "Print version information and exit.")

  opts, arguments = cli:parse()

  -- version + exit
  if opts['version'] then
    print("version: " .. VERSION)
    os.exit(0)
  end

  -- early exit if no inputs (checked before spawning any helper process)
  if #arguments == 0 then
    error('no input files')
  end

  -- gather OS info; translate() uses it to decide whether to link with -lrt.
  -- The pipe is closed explicitly so the handle is not leaked, and a failed
  -- popen simply leaves os_name unset.
  local uname = io.popen('uname -s', 'r')
  if uname then
    os_name = uname:read('*l')
    uname:close()
  end

  -- defaults
  opts['cuda-path'] = opts['cuda-path'] or CUDAPATH
  opts['cuda-gpu-arch'] = opts['cuda-gpu-arch'] or CUDAARCH
  opts['O'] = opts['O'] or '2'

  opts['I'] = opts['I'] or {}
  opts['L'] = opts['L'] or {}
  opts['l'] = opts['l'] or {}
  opts['D'] = opts['D'] or {}

  -- convenience translations: store the ready-to-use compiler flag, or a
  -- falsy value when the option was not given
  opts['g'] = opts['g'] and '-g'
  opts['std'] = opts['std'] and ('-std=' .. opts['std'])

  -- mutex constraints: at most one stop flag may be given
  local stop_flags = 0
  for _, flag in ipairs{'E', 'A', 'R', 'c'} do
    if opts[flag] then
      stop_flags = stop_flags + 1
    end
  end
  if stop_flags > 1 then
    error('error: only one allowed of E, A, R, c')
  end
end

-- Derives the enabled phases from the stop flags and stores them in `steps`.
-- Each stop flag (-E, -A, -R, -c) disables every phase after its own, so the
-- enabled phases always form a prefix of the pipeline. With
-- --ccc-print-phases the enabled phases are printed, numbered from 1.
function phases()
  local run_analyze = not opts["E"]
  local run_rewrite = run_analyze and not opts["A"]
  local run_compile = run_rewrite and not opts["R"]
  local run_link = run_compile and not opts["c"]

  steps["analyze"] = run_analyze
  steps["rewrite"] = run_rewrite
  steps["compile"] = run_compile
  steps["link"] = run_link

  if not opts["ccc-print-phases"] then
    return
  end

  -- preprocess is always phase 1; the optional phases keep fixed numbers 2..5
  print("1: preprocess")
  local optional = {"analyze", "rewrite", "compile", "link"}
  for position = 1, #optional do
    local name = optional[position]
    if steps[name] then
      printf("%d: %s", position + 1, name)
    end
  end
end

-- Step 2: construct the list of actions needed to turn all inputs into the
-- requested output. Every input becomes an "input" action, followed by the
-- analyze / rewrite / compile actions enabled in `steps`; all resulting
-- objects are finally combined by a single "link" action.
-- Raises an error for inputs of unknown type and for rewritten CUDA files
-- (.cua) that have no yaml model to be compiled with.
function pipeline()
  -- yaml files can be passed on the command line.
  -- Such an explicitly passed yaml file then disables analysis for the next
  -- cuda file and results in using the specified yaml file instead
  local yamls = {} -- yaml stack, enables behaviour above
  local objs = {}

  -- ipairs, not pairs: the order of the inputs is significant (a yaml file
  -- applies to the cuda file that follows it) and pairs gives no order
  -- guarantee.
  for _, input in ipairs(arguments) do
    repeat -- lua hack for continue (breaks now act as a continue)
      local ftype = guess_type(input)

      if ftype == "unknown" then
        error(string.format("don't know what to do with file '%s'", input))
      end

      actions[#actions+1] = {"input", {input}, ftype}
      local result = #actions

      if ftype == "yaml" then
        yamls[#yamls+1] = result
        break
      end

      -- true when the yaml on top of the stack was generated for this input
      local analyzed = false
      if ftype == "cuda" and steps["analyze"] and #yamls == 0 then
        local yaml = add_action("analyze", {result})
        yamls[#yamls+1] = yaml
        analyzed = true
      end

      local popyaml = false
      if ftype == "cuda" and steps["rewrite"] then
        assert(#yamls > 0,
          string.format("internal error: no yaml model for '%s'", input))
        result, ftype = add_action("rewrite", {result, yamls[#yamls]})
        popyaml = true
      end

      if ftype == "cuda2" and steps["compile"] then
        -- reachable by the user: a .cua input without a preceding yaml file
        assert(#yamls > 0,
          string.format("no yaml model available for '%s'", input))
        result, ftype = add_action("compile", {result, yamls[#yamls]})
        popyaml = true
      end

      -- yaml stack only popped here, because it's needed for both rewriting
      -- and compile. A yaml generated for this input is popped even when
      -- nothing consumed it (-A); otherwise it would stay on the stack and
      -- suppress the analysis of every following cuda file.
      if popyaml or analyzed then
        yamls[#yamls] = nil
      end

      if ftype == "c++" and steps["compile"] then
        result, ftype = add_action("compile-c++", {result})
      end

      if ftype == "c" and steps["compile"] then
        result, ftype = add_action("compile-c", {result})
      end

      if ftype == "obj" and steps["link"] then
        objs[#objs+1] = result
      end
    until true
  end
  if #yamls > 0 and not opts["A"] then
    print("warning: there were unused yaml files")
  end

  if #objs > 0 then
    add_action("link", objs)
  end

  if opts["ccc-print-actions"] then
    for num, action in ipairs(actions) do
      printf("%d: %s, [%s], %s", num, action[1], table.concat(action[2], ", "), action[3] or "wat")
    end
  end
end

-- Step 3: bind every action to concrete file names.
-- For each entry of `actions` this picks the output file (the user's -o for
-- the last action, otherwise a generated name), resolves the input files and
-- appends {phase, inputs, output, include = dir} to `bindings`.
-- Outputs that a later action consumes are registered in `temps` so that
-- cleanup() removes them; outputs nobody consumes are results the user asked
-- for (e.g. every object of `-c a.cu b.cu`) and are kept.
function bind()
  local action_outs = {}
  local base = ""

  -- lookup table to have correct "include origin" (#include "something.h")
  -- for rewritten .cua files, which (likely) have a different path than the
  -- original .cu file.
  -- maps: <full name of .cua> -> <directory of original cuda file>
  local original_paths = {}

  -- consumed[i] is true when the output of action i is an input of another
  -- action. "input" actions list file names instead of action indices and
  -- are therefore skipped.
  local consumed = {}
  for _, action in ipairs(actions) do
    if action[1] ~= "input" then
      for _, producer in ipairs(action[2]) do
        consumed[producer] = true
      end
    end
  end

  for idx, action in ipairs(actions) do
    -- collect inputs
    local inputs = {}
    for _, input in ipairs(action[2]) do
      inputs[#inputs+1] = action_outs[input]
    end

    local out = nil
    local include = nil

    -- if last step and user chose an output file, use that one.
    if idx == #actions and opts["o"] ~= nil then
      out = opts["o"]
    end

    -- only set 'out' to temporary if not already set above
    if action[1] == "input" then
      out = out or action[2][1]
      base = basename(out)
    elseif action[1] == "analyze" then
      out = out or make_temp(base .. "?.yaml")
    elseif action[1] == "rewrite" then
      out = out or make_temp(base .. "?.cua")
      -- Hackish, this point is where we know the original path of the input.
      -- We store it in the lookup table so later binding steps can
      -- look it up.
      local input = action_outs[action[2][1]] -- .cu
      original_paths[out] = dirname(input)
    elseif action[1] == "compile" then
      out = out or make_temp(base .. "?.o")
      -- If we do not find an original path for a .cua, we know it has been
      -- rewritten in a different run of mekc. In this case we can't do
      -- anything and just issue a warning.
      local input = action_outs[action[2][1]]
      local dir = original_paths[input] -- .cua
      if not dir then
        print("warning: did not find iquote directory for " .. input)
      end
      include = dir
    elseif action[1] == "compile-c" then
      out = out or make_temp(base .. "?.o")
    elseif action[1] == "compile-c++" then
      out = out or make_temp(base .. "?.o")
    elseif action[1] == "link" then
      out = out or "a.out"
    end

    -- Register every intermediate name right away: make_temp() consults
    -- `temps`, so this keeps later actions from picking the same name.
    if action[1] ~= "input" and idx ~= #actions then
      temps[out] = true
    end
    action_outs[idx] = out

    bindings[idx] = {action[1], inputs, out, include = include}
  end

  -- Outputs that no action consumes are final results, not temporaries:
  -- take them off the removal list again.
  for idx, action in ipairs(actions) do
    local out = action_outs[idx]
    if out ~= nil and action[1] ~= "input" and not consumed[idx] then
      temps[out] = nil
    end
  end

  if opts["ccc-print-bindings"] then
    for idx, binding in ipairs(bindings) do
      printf("%d: %s : [%s] -> %s", idx, binding[1], table.concat(binding[2], ", "), binding[3])
    end
  end
end

-- Reports whether any word of the command matches a Lua pattern.
-- cmd: list of strings; pat: Lua pattern tried against every word.
-- Returns true on the first matching word, false when none matches.
function has_flag(cmd, pat)
  for _, word in ipairs(cmd) do
    local first = word:find(pat)
    if first ~= nil then
      return true
    end
  end
  return false
end

-- Adds the options shared by all C-family compile commands to `cmd`:
-- debug flag, language standard, user include dirs, the mekong runtime
-- include dir, user macro definitions and, unless the command already has an
-- -O<digits> word, the optimisation level.
-- binding is accepted but not used. Returns the extended command.
function translate_common_c(binding, cmd)
  -- One word per append() call: append() stops at the first nil argument, so
  -- optional values must never share a call with later ones.
  local function push(word)
    cmd = append(cmd, word)
  end

  push(opts["g"])
  push(opts["std"])

  for _, dir in ipairs(opts["I"]) do
    push("-I" .. dir)
  end
  push("-I" .. MERTINC)

  for _, macro in ipairs(opts["D"]) do
    push("-D" .. macro)
  end

  local already_optimised = has_flag(cmd, '^-O%d*$')
  if not already_optimised then
    push("-O"..opts["O"])
  end

  return cmd
end

-- Step 4: translate every binding into a command (a list of shell words) and
-- append it to `cmds`. "input" bindings produce no command.
-- binding layout: [1] phase, [2] input files, [3] output file,
-- .include = directory passed via -iquote for rewritten CUDA sources.
function translate()
  -- Adds the CUDA target options shared by the analyze and compile commands.
  local function add_cuda_target(cmd)
    cmd = append(cmd, "--cuda-gpu-arch=" .. opts["cuda-gpu-arch"])
    if opts["cuda-path"] ~= nil then
      cmd = append(cmd, "--cuda-path=" .. opts["cuda-path"])
    end
    return cmd
  end

  for _, binding in ipairs(bindings) do
    local phase = binding[1]
    local cmd = nil

    if phase == "input" then
      -- inputs need no command

    elseif phase == "analyze" then
      -- device-only compile with the mekong plugin; the plugin writes the
      -- model to binding[3] and the compiler output itself is discarded
      cmd = { CLANGXX }
      cmd = append(cmd, "-fplugin=" .. LIBMEKONG)
      cmd = append(cmd, "-O1", "--cuda-device-only", "-mllvm", "-mekong-pre")
      cmd = append(cmd, "-mllvm", "-mekong-model=".. binding[3])
      cmd = add_cuda_target(cmd)
      cmd = translate_common_c(binding, cmd)
      cmd = append(cmd, binding[2][1])
      cmd = append(cmd, "-c", "-o", "/dev/null")

    elseif phase == "rewrite" then
      -- inputs: [1] cuda source, [2] yaml model
      cmd = { MEKONGRW }
      cmd = append(cmd, "--info=" .. binding[2][2])
      cmd = append(cmd, binding[2][1])
      cmd = append(cmd, "-o", binding[3])

    elseif phase == "compile" then
      -- inputs: [1] rewritten cuda source, [2] yaml model
      cmd = { CLANGXX }
      cmd = append(cmd, "-x", "cuda")
      if binding["include"] then
        cmd = append(cmd, "-iquote", binding["include"])
      end
      cmd = append(cmd, "-fplugin=" .. LIBMEKONG)
      cmd = append(cmd, "-mllvm", "-mekong", "-mllvm", "-mekong-model=".. binding[2][2])
      cmd = add_cuda_target(cmd)
      cmd = translate_common_c(binding, cmd)
      cmd = append(cmd, binding[2][1])
      cmd = append(cmd, "-c", "-o", binding[3])

    elseif phase == "compile-c" then
      cmd = { CLANG }
      cmd = append(cmd, "-x", "c")
      cmd = append(cmd, "-fplugin=" .. LIBMEKONG)
      cmd = append(cmd, "-mllvm", "-mekong")
      cmd = translate_common_c(binding, cmd)
      cmd = append(cmd, binding[2][1])
      cmd = append(cmd, "-c", "-o", binding[3])

    elseif phase == "compile-c++" then
      cmd = { CLANGXX }
      cmd = append(cmd, "-x", "c++")
      cmd = append(cmd, "-fplugin=" .. LIBMEKONG)
      cmd = append(cmd, "-mllvm", "-mekong")
      cmd = translate_common_c(binding, cmd)
      cmd = append(cmd, binding[2][1])
      cmd = append(cmd, "-c", "-o", binding[3])

    elseif phase == "link" then
      cmd = { CLANGXX }
      cmd = append(cmd, "-O"..opts["O"])
      cmd = append(cmd, opts["g"])
      for _, libdir in ipairs(opts["L"]) do
        cmd = append(cmd, "-L" .. libdir)
      end
      for _, input in ipairs(binding[2]) do
        cmd = append(cmd, input)
      end
      cmd = append(cmd, MERTLIB)
      -- Libraries go after the objects that use them: the linker processes
      -- its arguments in order and a static library listed before the
      -- objects would not resolve any of their symbols. User libraries come
      -- before the CUDA/system libraries so they may depend on those.
      for _, library in ipairs(opts["l"]) do
        cmd = append(cmd, "-l" .. library)
      end
      cmd = append(cmd, "-lcudart_static", "-ldl", "-pthread")
      if os_name == "Linux" then
        cmd = append(cmd, "-lrt")
      end
      cmd = append(cmd, "-o", binding[3])

    end

    if cmd ~= nil then
      cmds[#cmds+1] = cmd
    end
  end
end

-- Step 5: run the assembled commands in order.
-- With -v or -### each command line is printed; with -### nothing is run.
-- Raises an error as soon as one command fails.
function execute()
  for _, cmd in ipairs(cmds) do
    local cmd_str = table.concat(cmd, " ")
    if opts["v"] or opts["###"] then
      print(cmd_str)
    end
    if not opts["###"] then
      -- os.execute reports success as 0 on Lua 5.1 / LuaJIT and as true on
      -- Lua 5.2 and later, so both forms have to be accepted.
      local status = os.execute(cmd_str)
      if status ~= true and status ~= 0 then
        error("error encountered, compilation aborted")
      end
    end
  end
end

-- Removes every file registered in `temps`. In -### (dry-run) mode the file
-- names are only printed. Failures of os.remove are ignored.
function cleanup()
  for path in pairs(temps) do
    -- looked up per file: `opts` is only touched when there is work to do
    local dry_run = opts["###"]
    if not dry_run then
      os.remove(path)
    else
      print("removing temporary " .. path)
    end
  end
end

--------------------------------------------------------------------------------
-- utility functions

-- Formats the arguments with string.format and prints the result as one line.
function printf(fmt, ...)
  local line = string.format(fmt, ...)
  print(line)
end

-- Quotes a single command-line word for the shell that os.execute() starts.
-- Words made up only of characters without special meaning to the shell are
-- returned unchanged, so typical command lines look the same as before.
-- Everything else (including the empty string) is wrapped in single quotes,
-- inside which the shell interprets nothing; an embedded single quote is
-- written as '\'' (close the quotes, escaped quote, reopen the quotes).
function escape(str)
  local needs_quoting = str == "" or str:find("[^%w_%-+=:,./@%%]") ~= nil
  if not needs_quoting then
    return str
  end
  -- extra local: gsub also returns the replacement count
  local quoted = str:gsub("'", "'\\''")
  return "'" .. quoted .. "'"
end

-- Returns a new list holding the elements of `list` followed by the shell-
-- escaped extra arguments. `list` itself is not modified and its elements
-- are copied as they are (they were escaped when they were appended).
-- nil and false arguments are skipped, wherever they appear.
function append(list, ...)
  local new = {}
  for _, el in ipairs(list) do
    new[#new+1] = el
  end
  -- select("#", ...) counts nil arguments too; iterating over {...} with
  -- ipairs would stop at the first nil and drop everything after it.
  for i = 1, select("#", ...) do
    local el = select(i, ...)
    if el then
      new[#new+1] = escape(el)
    end
  end
  return new
end

-- Appends an action for `phase` with the given input list to `actions`.
-- Returns the index of the new action and the type of its result.
function add_action(phase, input)
  local kind = phase_result(phase)
  local index = #actions + 1
  actions[index] = {phase, input, kind}
  return index, kind
end

-- Result type produced by each phase.
local PHASE_RESULT_TYPES = {
  ["preprocess"] = "cuda",
  ["analyze"] = "yaml",
  ["rewrite"] = "cuda2",
  ["compile"] = "obj",
  ["compile-c"] = "obj",
  ["compile-c++"] = "obj",
  ["link"] = "binary",
}

-- Returns the type of file a phase produces; raises an error for a phase
-- name that is not known.
function phase_result(phase)
  local kind = PHASE_RESULT_TYPES[phase]
  if kind ~= nil then
    return kind
  end
  assert(false, string.format("invalid phase '%s'", phase))
end

-- Reports whether `str`, converted to lower case, ends with one of the given
-- suffixes. Suffixes are compared literally, so characters that are magic in
-- Lua patterns ("+", "-", "%", ...) need no escaping. Because `str` is
-- lower-cased first, suffixes must be given in lower case.
function has_suffix(str, ...)
  local lowered = str:lower()
  for _, suffix in ipairs({...}) do
    -- an empty suffix matches every string
    if suffix == "" or lowered:sub(-#suffix) == suffix then
      return true
    end
  end
  return false
end

-- Guesses the type of an input file from its (case-insensitive) extension.
-- Returns one of "cuda", "cuda2" (rewritten cuda), "yaml", "c", "c++", "obj",
-- or "unknown" when the extension is not recognised.
function guess_type(file)
  if has_suffix(file, ".cu") then return "cuda" end
  if has_suffix(file, ".cua") then return "cuda2" end -- rewritten
  if has_suffix(file, ".yaml") then return "yaml" end
  if has_suffix(file, ".c") then return "c" end
  if has_suffix(file, ".cpp", ".cc", ".cxx") then return "c++" end
  -- LLVM IR / bitcode inputs are passed to the link step like objects
  if has_suffix(file, ".ll", ".bc", ".o") then return "obj" end
  return "unknown"
end

-- Returns path stripped of its directories and of its last extension, e.g.
-- "src/kernel.cu" -> "kernel" and "a/b.tar.gz" -> "b.tar".
-- Exactly one value is returned: the intermediate locals drop the
-- replacement count that string.gsub returns as a second value.
function basename(path)
  local name = path:gsub("(.*/)(.*)", "%2")
  local stem = name:gsub("(.*)(%..*)", "%1")
  return stem
end

-- Returns the directory part of path:
--   "a/b/c.cu" -> "a/b"
--   "c.cu"     -> "."   (no directory component)
--   "/c.cu"    -> "/"   (file directly under the root)
function dirname(path)
  local dir = path:gsub("(.*)/.*", "%1")
  if dir == path then
    -- nothing was stripped: the path contains no slash
    return "."
  end
  if dir == "" then
    -- everything before the last slash is empty: the parent is the root,
    -- and returning "" would produce an empty command-line argument
    return "/"
  end
  return dir
end

-- Reports whether `path` can be opened for reading.
function exists(path)
  local handle = io.open(path, "r")
  if handle == nil then
    return false
  end
  handle:close()
  return true
end

-- Alphabet used for random file-name suffixes: [0-9A-Za-z].
local charset = {}
for _, range in ipairs{ {48, 57}, {65, 90}, {97, 122} } do
  for c = range[1], range[2] do
    charset[#charset+1] = string.char(c)
  end
end

-- Seed from wall-clock time plus CPU time in microseconds. The previous seed,
-- os.clock()^5, is a tiny fraction right after start-up: it truncates to the
-- same seed on every run and is not a valid argument for Lua versions that
-- require an integer seed.
math.randomseed(os.time() + math.floor(os.clock() * 1000000))

-- Returns a string of `length` characters drawn from `charset`; an empty
-- string when `length` is nil or not positive.
local function randomString(length)
  if not length or length <= 0 then return '' end
  local chars = {}
  local remaining = length
  while remaining > 0 do
    chars[#chars+1] = charset[math.random(1, #charset)]
    remaining = remaining - 1
  end
  return table.concat(chars)
end

-- Picks a file name for a generated file. `pattern` contains "?" where a
-- random suffix may be inserted. The name without any suffix is tried first;
-- while the candidate exists on disk or is already registered in `temps`,
-- a new one with "-" plus three random characters is tried.
function make_temp(pattern)
  local candidate = pattern:gsub("?", "")
  while true do
    if not exists(candidate) and temps[candidate] == nil then
      return candidate
    end
    candidate = pattern:gsub("?", "-" .. randomString(3))
  end
end

--------------------------------------------------------------------------------

-- Runs the driver steps in order: parse the arguments, select the phases,
-- construct the actions, bind them to files, translate them to commands and
-- execute those. An error raised by any step propagates to the caller.
function main()
  local step_names = {
    "parse", "phases", "pipeline", "bind", "translate", "execute",
  }
  for _, name in ipairs(step_names) do
    -- each step is a global function, looked up when its turn comes
    _G[name]()
  end
end

-- Entry point. main() runs under pcall so that cleanup() always gets the
-- chance to remove temporary files, even when a step raised an error; the
-- error is then reported and turned into exit status 1.
local ok, err = pcall(main)
cleanup()
if not ok then
  print(err)
  os.exit(1)
end
